昨天已經把 RoPE 觀念介紹完了,會發現數學公式比絕對位置編碼複雜一點,但直接應用在 Q 跟 K 而不是與詞量相加,這樣子更有效率。
參考文章:
https://www.cnblogs.com/rossiXYZ/p/18787343
先複習一下之前位置編碼有出現的圖,當初實作時透過公式事先將表格算好,用於後續計算,那這裡可以整理成圖上的第二步。
這裡 RoPE 我們也會事先計算 cos(mθ), sin(mθ),儲存起來,然後再與 query, key 做相乘。
這裡實作會有兩種版本,以下程式用方法一來實作,比較能跟公式有對應。
會實作三個 function 分別是:
上面的 m 跟我們之前用的 pos 是一樣的,代表長度,我們先看一下需要儲存哪些,跟之前一樣假設我們輸入長度為 4, d 的維度為 6,那會像下面,x 的部分會是 x_1 ~ x_6,m = 0 時 x 乘上第一個 cos 矩陣,m = 1 時乘上第二個矩陣一直到 m = 3 (輸入長度 - 1),所以假設需要 4 * 6 = 24 個位置儲存
當中 𝜃 的計算跟之前一樣,然後我們使用 torch outer 來計算剛才的m𝜃,可以想像 a[i] 就是 m, b[j] 就是 𝜃
圖片連結: https://zhuanlan.zhihu.com/p/714192908
outer 計算完會長的像下面這樣
步驟如下:
import torch
def precompute_freqs_cis(hidden_size, max_seq_len):
# step 1, 2 與之前計算一樣
inv_freq = 1 / (10000 ** (torch.arange(0, hidden_size, 2).float() / hidden_size))
m = torch.arange(max_seq_len)
# step 3: 使用 torch.outer 來計算出 m𝜃 的部分
freqs = torch.outer(m, inv_freq).float()
print(f'm shape: {m.shape}, inv_freq shape: {inv_freq.shape}')
print(f'freqs shape: {freqs.shape}')
# step 4: 創建 cos 矩陣, 以及 sin 矩陣
freqs_cos = torch.repeat_interleave(torch.cos(freqs), repeats=2, dim=-1)
freqs_sin = torch.repeat_interleave(torch.cos(freqs), repeats=2, dim=-1)
print(f'repeat 之前:\n {torch.cos(freqs)}')
print(f'repeat 之後:\n {freqs_cos}')
return freqs_cos, freqs_sin
if __name__ == "__main__":
precompute_freqs_cis(6, 4)
在這邊先停一下,讓我們先看看下面,會發現公式的順序不太一樣,其中一個與論文一樣,一個是 huggingface 的實作,那有研究發現只要維度是偶數,那不管用下面哪一種,最終 attention 內積結果都能感知到相對位置訊息,那這邊我們照原先論文的來實作。
現在來實作 rotate_half,也就是將 x 的部分做旋轉,以下給出三種常看到的方式,有照論文也有照 huggingface,可以先猜猜看。
import torch
def rotate_half_v1(x: torch.Tensor):
x1, x2 = x.chunk(2, dim = -1)
return torch.cat((-x2, x1), dim = -1)
def rotate_half_v2(x):
x1 = x[..., 0::2]
x2 = x[..., 1::2]
return torch.stack((-x2, x1), dim=-1).flatten(-2)
def rotate_half_v3(x):
return torch.cat((-x[..., x.shape[-1] // 2:], x[..., : x.shape[-1] // 2]), dim=-1)
if __name__ == "__main__":
x = torch.tensor([1., 2., 3., 4., 5., 6.])
v1 = rotate_half_v1(x)
v2 = rotate_half_v2(x)
v3 = rotate_half_v3(x)
print(x)
print(v1, v2, v3)
會發現 v2 是論文的版本。
那麼這裡稍微提一下, repeat_interleave 會搭配 v2,另外 v1, v3 會搭配 torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1)
接下來就是如何將位置資訊融入 q, k 得到 q_rope, k_rope。
步驟如下:
import torch
def apply_rope(q, k, cos: torch.Tensor, sin: torch.Tensor):
'''
q, k: (B, L, n_head, head_dim)
cos, sin: (L, head_dim)
'''
# 確保 cos 和 sin 的維度與 q, k 匹配
cos = cos.unsqueeze(0).unsqueeze(2) # (1, L, 1, head_dim)
sin = sin.unsqueeze(0).unsqueeze(2)
q_rope = (q * cos) + (rotate_half_v2(q) * sin)
k_rope = (k * cos) + (rotate_half_v2(k) * sin)
return q_rope, k_rope
def precompute_freqs_cis(head_dim, max_seq_len):
# step 1, 2 與之前計算一樣
inv_freq = 1 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
m = torch.arange(max_seq_len)
# step 3: 使用 torch.outer 來計算出 m𝜃 的部分
freqs = torch.outer(m, inv_freq).float()
print(f'm shape: {m.shape}, inv_freq shape: {inv_freq.shape}')
print(f'freqs shape: {freqs.shape}')
# step 4: 創建 cos 矩陣, 以及 sin 矩陣
freqs_cos = torch.repeat_interleave(torch.cos(freqs), repeats=2, dim=-1)
freqs_sin = torch.repeat_interleave(torch.sin(freqs), repeats=2, dim=-1)
print(f'repeat 之前:\n {torch.cos(freqs)}')
print(f'repeat 之後:\n {freqs_cos}')
return freqs_cos, freqs_sin
def rotate_half_v1(x: torch.Tensor):
x1, x2 = x.chunk(2, dim = -1)
return torch.cat((-x2, x1), dim = -1)
def rotate_half_v2(x):
x1 = x[..., 0::2]
x2 = x[..., 1::2]
return torch.stack((-x2, x1), dim=-1).flatten(-2)
def rotate_half_v3(x):
return torch.cat((-x[..., x.shape[-1] // 2:], x[..., : x.shape[-1] // 2]), dim=-1)
if __name__ == "__main__":
batch_size = 1
seq_len = 4
max_seq_len = 10
n_head = 2
head_dim = 6
# 模擬 multi-head 的 q, k
q = torch.randn(batch_size, seq_len, n_head, head_dim)
k = torch.randn(batch_size, seq_len, n_head, head_dim)
# Step 1: 事先計算 freqs
freqs_cos, freqs_sin = precompute_freqs_cis(head_dim, max_seq_len)
print(f"freqs_cos shape: {freqs_cos.shape}")
print(f"freqs_sin shape: {freqs_sin.shape}")
# Step 2: 套用 RoPE
freqs_cos = freqs_cos[: seq_len] # 依照當前 len 做截斷
freqs_sin = freqs_sin[: seq_len]
q_rope, k_rope = apply_rope(q, k, freqs_cos, freqs_sin)
雖然才三個 block,不過需要花些時間理解,今天就到這裡囉~ 明天是完賽心得,所以到這裡實作就算最後一篇了。